Synthetic models for posterior distributions¶

Marco Raveri (marco.raveri@unige.it), Cyrille Doux (doux@lpsc.in2p3.fr), Shivam Pandey (shivampcosmo@gmail.com)

In this notebook we show how to build synthetic normalizing flow models for posterior distributions, as in Raveri, Doux and Pandey (2024), arXiv:2409.09101.

Table of contents¶

  1. Notebook setup
  2. Base example
  3. Average flow example
  4. Caching flows
  5. Real world application: joint parameter estimation
  6. Real application: accurate likelihood values
  7. Advanced Topic: Spline flows

Notebook setup:¶

In [1]:
# Show plots inline, and load main getdist plot module and samples class
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

# import libraries:
import sys, os
os.environ['TF_USE_LEGACY_KERAS'] = '1'  # needed for tensorflow KERAS compatibility
os.environ['DISPLAY'] = 'inline'  # hack to get getdist working
sys.path.insert(0,os.path.realpath(os.path.join(os.getcwd(),'../..')))
from getdist import plots, MCSamples
from getdist.gaussian_mixtures import GaussianND
import getdist
getdist.chains.print_load_details = False
import scipy
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# tensorflow imports:
import tensorflow as tf
import tensorflow_probability as tfp

# import the tensiometer tools that we need:
import tensiometer
from tensiometer.utilities import stats_utilities as utilities
from tensiometer.synthetic_probability import synthetic_probability as sp

# getdist settings to ensure consistency of plots:
getdist_settings = {'ignore_rows': 0.0, 
                    'smooth_scale_2D': 0.3,
                    'smooth_scale_1D': 0.3,
                    }    
2024-12-12 00:51:19.103387: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

We start by building a random Gaussian mixture that we are going to use for tests:

In [2]:
# define the parameters of the problem:
dim = 6
num_gaussians = 3
num_samples = 10000

# we seed the random number generator to get reproducible results:
seed = 100
np.random.seed(seed)
# we define the range for the means and covariances:
mean_range = (-0.5, 0.5)
cov_scale = 0.4**2
# means and covs:
means = np.random.uniform(mean_range[0], mean_range[1], num_gaussians*dim).reshape(num_gaussians, dim)
weights = np.random.rand(num_gaussians)
weights = weights / np.sum(weights)
covs = [cov_scale*utilities.vector_to_PDM(np.random.rand(int(dim*(dim+1)/2))) for _ in range(num_gaussians)]

# cast to required precision:
means = means.astype(np.float32)
weights = weights.astype(np.float32)
covs = [cov.astype(np.float32) for cov in covs]

# initialize distribution:
distribution = tfp.distributions.Mixture(
    cat=tfp.distributions.Categorical(probs=weights),
    components=[
        tfp.distributions.MultivariateNormalTriL(loc=_m, scale_tril=tf.linalg.cholesky(_c))
        for _m, _c in zip(means, covs)
    ], name='Mixture')

# sample the distribution:
samples = distribution.sample(num_samples).numpy()
# calculate log posteriors:
logP = distribution.log_prob(samples).numpy()

# create MCSamples from the samples:
chain = MCSamples(samples=samples, 
                    settings=getdist_settings,
                    loglikes=-logP,
                    name_tag='Mixture',
                    )

# we make a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot(chain, filled=True)
    

Base example:¶

We train a normalizing flow on samples of a given distribution.

We initialize and train the normalizing flow on samples of the distribution we have just defined:

In [3]:
kwargs = {
          'feedback': 2,
          'plot_every': 1000,
          'pop_size': 1,
          #'cache_dir': 'test',  # set this to a directory to cache the results
          #'root_name': 'test',  # sets the name of the flow for the cache files
        }

flow = sp.flow_from_chain(chain,  # parameter difference chain
                          **kwargs)
* Initializing samples
    - flow name: Mixture_flow
    - precision: <dtype: 'float32'>
    - flow parameters and ranges:
      param1 : [-1.36545, 1.17324]
      param2 : [-1.3803, 1.0535]
      param3 : [-1.26807, 0.802192]
      param4 : [-0.836911, 1.33515]
      param5 : [-1.69169, 1.38861]
      param6 : [-1.5119, 0.921801]
    - periodic parameters: []
    - time taken: 0.0012 seconds
* Initializing fixed bijector
    - using prior bijector: ranges
    - rescaling samples
    - time taken: 0.1434 seconds
* Initializing trainable bijector
    Building Autoregressive Flow
    - # parameters          : 6
    - periodic parameters   : None
    - # transformations     : 8
    - hidden_units          : [12, 12]
    - transformation_type   : affine
    - autoregressive_type   : masked
    - permutations          : True
    - scale_roto_shift      : False
    - activation            : <function asinh at 0x7fef80ac0430>
    - time taken: 0.9316 seconds
* Initializing training dataset
    - 9000/1000 training/test samples and uniform weights
    - time taken: 0.9942 seconds
* Initializing transformed distribution
    - time taken: 0.0072 seconds
* Initializing loss function
    - using standard loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0190 seconds
    - trainable parameters : 3168
    - maximum learning rate: 0.001
    - minimum learning rate: 1e-06
    - time taken: 1.3315 seconds
* Training
    - Compiling model
    - time taken: 0.0082 seconds
Epoch 1/100
20/20 - 5s - loss: 8.5297 - val_loss: 8.3729 - lr: 0.0010 - 5s/epoch - 270ms/step
Epoch 2/100
20/20 - 0s - loss: 8.4937 - val_loss: 8.3418 - lr: 0.0010 - 307ms/epoch - 15ms/step
[... epochs 3-98 elided: loss decreases steadily from 8.46 to 7.41, val_loss from 8.31 to 7.28 ...]
Epoch 99/100
20/20 - 0s - loss: 7.4080 - val_loss: 7.2838 - lr: 0.0010 - 308ms/epoch - 15ms/step
Epoch 100/100
20/20 - 0s - loss: 7.4063 - val_loss: 7.2757 - lr: 0.0010 - 321ms/epoch - 16ms/step
* Population optimizer:
    - best model is number 1
    - best loss function is 7.41
    - best validation loss function is 7.28
    - population losses [7.28]
In [4]:
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()
In [5]:
# and we can print the training summary:
flow.print_training_summary()
loss         : 7.4063
val_loss     : 7.2757
lr           : 0.0010
chi2Z_ks     : 0.0694
chi2Z_ks_p   : 1.2405e-04
loss_rate    : -0.0017
val_loss_rate: -0.0082
In [6]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, flow.MCSamples(20000)], 
                params=flow.param_names,
                filled=True)
In [7]:
# this looks nice but not perfect, let's train for longer:
flow.feedback = 1
flow.train(epochs=300, verbose=-1);  # verbose = -1 uses tqdm progress bar
In [8]:
# we can plot training summaries to make sure training went smoothly:
flow.training_plot()

If you train for long enough, you should start seeing the learning rate adapt to the non-improving (noisy) loss function.

This means that the flow is learning finer and finer features, and is a good indication that training is converging. If you push it further, at some point the flow will start overfitting and training will stop.
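The learning-rate adaptation follows the usual reduce-on-plateau logic; below is a minimal standalone sketch of that logic (a hypothetical reimplementation for illustration, not the actual training callback):

```python
import numpy as np

def reduce_lr_on_plateau(losses, lr=1e-3, factor=0.5, patience=5, min_lr=1e-6):
    """Return the learning-rate schedule implied by a loss history:
    the rate is multiplied by `factor` whenever the loss has not
    improved for `patience` consecutive epochs."""
    best, wait, schedule = np.inf, 0, []
    for loss in losses:
        if loss < best:           # improvement: reset the patience counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:  # plateau detected: shrink the learning rate
                lr = max(lr * factor, min_lr)
                wait = 0
        schedule.append(lr)
    return schedule
```

Feeding in a loss history that stops improving produces successive reductions of the rate, which is the behavior visible in the training plots above.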

Now let's see what the marginal distributions look like:

In [9]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, 
                 flow.MCSamples(20000)  # this flow method returns a MCSamples object
                 ], 
                params=flow.param_names,
                filled=True)

This is now much better!

We can use the trained flow to perform several operations. For example, let's compute log-likelihoods:

In [10]:
samples = flow.MCSamples(20000)
logP = flow.log_probability(flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain], 
                plot_3d_with_param='logP',
                filled=False)

Here we can appreciate a beautiful display of a projection effect: the marginal distribution of $p_5$ peaks at a positive value, while the logP plot clearly shows that the peak of the full distribution lies at negative values.

If you are interested in understanding these types of effects systematically, check the corresponding tensiometer tutorial!
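The projection effect can be reproduced with a toy two-dimensional mixture, where a broad component dominates the 1D marginal while a narrow, taller component hosts the global mode (the numbers below are purely illustrative and unrelated to the chain above):

```python
import numpy as np
from scipy.stats import norm

# component A: broad in y, centered at x = +0.5; component B: narrow, at x = -0.5
w_A, mu_A, sx_A, sy_A = 0.7, +0.5, 0.2, 1.0
w_B, mu_B, sx_B, sy_B = 0.3, -0.5, 0.3, 0.05

def joint_density(x, y):
    """Two-component Gaussian mixture density in 2D."""
    return (w_A * norm.pdf(x, mu_A, sx_A) * norm.pdf(y, 0.0, sy_A)
            + w_B * norm.pdf(x, mu_B, sx_B) * norm.pdf(y, 0.0, sy_B))

def marginal_x(x):
    """Marginal density of x (y integrated out analytically)."""
    return w_A * norm.pdf(x, mu_A, sx_A) + w_B * norm.pdf(x, mu_B, sx_B)

x_grid = np.linspace(-1.5, 1.5, 601)
x_marginal_peak = x_grid[np.argmax(marginal_x(x_grid))]
# the marginal peaks near x = +0.5, yet the global mode of the joint
# density sits at (x, y) = (-0.5, 0), where component B is narrow and tall
```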

Average flow example:¶

A more advanced flow model consists of training several flows and using them as a weighted mixture of normalizing flows.

This flow model reduces the variance of the flow in regions that are scarce in samples (as different flow models will hallucinate differently)...
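Conceptually, the averaged model evaluates a weighted mixture density, $\log p(x) = \log \sum_i w_i \, p_i(x)$, computed stably in log space via logsumexp. A minimal numpy/scipy sketch with two Gaussian densities standing in for the trained flows (illustrative only, not the tensiometer API):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def mixture_log_prob(log_probs, weights):
    """Combine per-model log densities (shape: models x points) into a
    weighted-mixture log density, working in log space for stability."""
    log_w = np.log(np.asarray(weights))[:, None]
    return logsumexp(log_w + np.stack(log_probs, axis=0), axis=0)

x = np.linspace(-3.0, 3.0, 7)
lp1 = norm.logpdf(x, loc=0.0, scale=1.0)   # "flow" 1
lp2 = norm.logpdf(x, loc=0.1, scale=1.1)   # "flow" 2
lp = mixture_log_prob([lp1, lp2], weights=[0.5, 0.5])
```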

Let's try averaging 5 flow models (note that we could do this in parallel with MPI on bigger machines):

In [11]:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 5,
          'epochs': 400,
        }

average_flow = sp.average_flow_from_chain(chain,  # parameter difference chain
                                          **kwargs)
Training flow 0
Training flow 1
Training flow 2
Training flow 3
Training flow 4
In [12]:
# most methods are implemented for the average flow as well:
average_flow.training_plot()
In [13]:
# and we can print the training summary, which in this case contains more info:
average_flow.print_training_summary()
Number of flows: 5
Flow weights   : [0.2  0.2  0.21 0.2  0.2 ]
loss         : [7.25 7.23 7.25 7.26 7.25]
val_loss     : [7.28 7.26 7.22 7.27 7.28]
lr           : [1.00e-05 1.00e-04 3.16e-05 1.00e-05 3.16e-05]
chi2Z_ks     : [0.05 0.04 0.03 0.04 0.03]
chi2Z_ks_p   : [0.02 0.05 0.19 0.09 0.48]
loss_rate    : [ 9.06e-06  2.40e-04 -4.67e-04  2.29e-05 -1.10e-04]
val_loss_rate: [-8.58e-05  0.00e+00 -9.16e-04 -9.73e-05  2.85e-04]
In [14]:
avg_samples = average_flow.MCSamples(20000)
avg_samples.name_tag = 'Average Flow'
temp_samples = [_f.MCSamples(20000) for _f in average_flow.flows]
for i, _s in enumerate(temp_samples):
    _s.name_tag = _s.name_tag + f'_{i}'
# let's plot the flows:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, avg_samples] + temp_samples,
                filled=False)
WARNING:tensorflow:5 out of the last 8 calls to <function FlowCallback.log_probability at 0x7fede0139940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 9 calls to <function FlowCallback.log_probability at 0x7fed8070dd30> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
In [15]:
logP = average_flow.log_probability(average_flow.cast(avg_samples.samples)).numpy()
avg_samples.addDerived(logP, name='logP', label='\\log P')
avg_samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([avg_samples, chain], 
                plot_3d_with_param='logP',
                filled=False)

Caching flows¶

In this section we discuss how to effectively cache flows to avoid retraining them.

The main obstacle to doing this easily is tensorflow, which cannot pickle models... This means the flow needs to be rebuilt (although not re-trained) every time we want to use it.

We implement two caching strategies, depending on the use case.

In [16]:
# we create a temporary directory to store the caches:
import tempfile
temp_dir = tempfile.TemporaryDirectory()
print(f'Temporary directory created at: {temp_dir.name}')
Temporary directory created at: /tmp/tmprd2sr44d

The first caching strategy can be used when the cached flow is used in place, in the same script.

To enable caching, give the flow constructor a cache_dir argument and a root_name argument. The flow will then save several files in the cache_dir directory with the root_name prefix.

In [17]:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'epochs': 10,
          'pop_size': 1,
          'cache_dir': temp_dir.name,  # set this to a directory to cache the results
          'root_name': 'test_1',  # sets the name of the flow for the cache files
        }

cached_flow_1 = sp.flow_from_chain(chain, **kwargs)
* Initializing samples
    - time taken: 0.0011 seconds
* Initializing fixed bijector
    - time taken: 0.1388 seconds
* Initializing trainable bijector
    - time taken: 0.9645 seconds
* Initializing training dataset
    - time taken: 1.0166 seconds
* Initializing transformed distribution
    - time taken: 0.0069 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0126 seconds
    - time taken: 1.3153 seconds
* Training
    - Compiling model
    - time taken: 0.0082 seconds
In [18]:
# retrieve the flow by executing the same call as before; note that the flow will not re-train this time:
cached_flow_1_test = sp.flow_from_chain(chain, **kwargs)
* Initializing samples
    - time taken: 0.0012 seconds
* Initializing fixed bijector
    - time taken: 0.1397 seconds
* Initializing trainable bijector
    - time taken: 0.9445 seconds
* Initializing training dataset
    - time taken: 0.9974 seconds
* Initializing transformed distribution
    - time taken: 0.0070 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
In [19]:
# let's check that the cache worked:
_temp_samples = cached_flow_1.sample(10)
_temp_like_1 = cached_flow_1.log_probability(_temp_samples).numpy()
_temp_like_2 = cached_flow_1_test.log_probability(_temp_samples).numpy()
print('All likelihoods are close?', np.allclose(_temp_like_1, _temp_like_2))
All likelihoods are close? True

The second caching strategy can be used when the cached flow has to be used in different scripts.

In this use case we probably don't have access to all the data and settings needed to rebuild the flow. So, when building the flow, we instruct it to save to cache all the information needed to rebuild it. This is (relatively) expensive in disk space, because it saves a copy of the training set and the settings. MCMC chains are typically not too big, but their disk usage can add up... This is the reason why this caching strategy is slightly more involved.
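The idea behind caching.cache_input can be sketched as a decorator that pickles the call inputs on the first call and reloads them when only the cache location is passed. This is a simplified, hypothetical reimplementation for illustration, not the tensiometer code:

```python
import os
import pickle
import functools

def cache_input(func):
    """Pickle the wrapped function's inputs to cache_dir/root_name on the
    first call; later calls with only the cache location reload them."""
    @functools.wraps(func)
    def wrapper(*args, cache_dir=None, root_name=None, **kwargs):
        cache_file = None
        if cache_dir is not None and root_name is not None:
            cache_file = os.path.join(cache_dir, root_name + '_input.pickle')
        if cache_file and not args and not kwargs and os.path.isfile(cache_file):
            with open(cache_file, 'rb') as f:
                args, kwargs = pickle.load(f)      # rebuild from cached inputs
        elif cache_file:
            with open(cache_file, 'wb') as f:
                pickle.dump((args, kwargs), f)     # save inputs for later scripts
        return func(*args, cache_dir=cache_dir, root_name=root_name, **kwargs)
    return wrapper
```

The real implementation also has to serialize the chain samples and flow settings, but the pattern is the same: the first call stores its inputs, a bare call restores them.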

In [20]:
# first we need to import appropriate utilities:
from tensiometer.utilities import caching

# then we need to decorate the method we want to cache (so that this works for all types of flows)
cached_flow_from_chain = caching.cache_input(sp.flow_from_chain)

# then we can call the decorated method as before:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'epochs': 10,
          'pop_size': 1,
          'cache_dir': temp_dir.name,  # set this to a directory to cache the results
          'root_name': 'test_2',  # sets the name of the flow for the cache files
        }

cached_flow_2 = cached_flow_from_chain(chain, **kwargs)
* Initializing samples
    - time taken: 0.0015 seconds
* Initializing fixed bijector
    - time taken: 0.1523 seconds
* Initializing trainable bijector
    - time taken: 0.9265 seconds
* Initializing training dataset
    - time taken: 0.9943 seconds
* Initializing transformed distribution
    - time taken: 0.0072 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0124 seconds
    - time taken: 2.6310 seconds
* Training
    - Compiling model
    - time taken: 0.0103 seconds
In [21]:
# now pretend this is a different script...

# To load the flow, we then only need to pass cache_dir and root_name to the decorated method:

# in the new script we re-decorate the method:
other_cached_flow_from_chain = caching.cache_input(sp.flow_from_chain)

# then we can call the decorated method passing only the cache details:
kwargs = {
          'cache_dir': temp_dir.name,  # set this to a directory to cache the results
          'root_name': 'test_2',  # sets the name of the flow for the cache files
        }
cached_flow_2_test = other_cached_flow_from_chain(**kwargs)
* Initializing samples
    - time taken: 0.0009 seconds
* Initializing fixed bijector
    - time taken: 0.1388 seconds
* Initializing trainable bijector
    - time taken: 0.9501 seconds
* Initializing training dataset
    - time taken: 0.9972 seconds
* Initializing transformed distribution
    - time taken: 0.0069 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
In [22]:
# let's check that the cache worked:
_temp_samples = cached_flow_2.sample(10)
_temp_like_1 = cached_flow_2.log_probability(_temp_samples).numpy()
_temp_like_2 = cached_flow_2_test.log_probability(_temp_samples).numpy()
print('All likelihoods are close?', np.allclose(_temp_like_1, _temp_like_2))
All likelihoods are close? True

Real world application: joint parameter estimation¶

In this example we perform a flow-based analysis of a joint posterior.

The idea is that we have posterior samples from two independent experiments; we learn the two posteriors and then combine them to form the joint posterior.

Note that we are assuming - as is true in this example - that the prior is the same for the two experiments and flat (so that we are not duplicating the prior).

This procedure was used, for example, in Gatti, Campailla et al (2024), arXiv:2405.10881.
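For intuition: with flat, shared priors, the joint log posterior is just the sum of the two experiments' log posteriors. In the 1D Gaussian case this reduces to the familiar inverse-variance combination, which we can verify with numpy (a toy check, not the flow-based computation below):

```python
import numpy as np

def combine_gaussians(mu1, sigma1, mu2, sigma2):
    """Multiply two 1D Gaussian posteriors (flat prior): the result is
    Gaussian with precision-weighted mean and summed precisions."""
    prec1, prec2 = 1.0 / sigma1**2, 1.0 / sigma2**2
    mu12 = (prec1 * mu1 + prec2 * mu2) / (prec1 + prec2)
    sigma12 = np.sqrt(1.0 / (prec1 + prec2))
    return mu12, sigma12

mu12, sigma12 = combine_gaussians(0.0, 1.0, 2.0, 1.0)
# two equally constraining experiments: combined mean halfway between
# the two, combined error shrunk by a factor sqrt(2)
```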

In [23]:
# we start by loading up the posteriors:

# load the samples (no burn-in removal needed since the example chains have already been cleaned):
chains_dir = os.path.realpath(os.path.join(os.getcwd(), '../..', 'test_chains'))
# the Planck 2018 TTTEEE chain:
chain_1 = getdist.mcsamples.loadMCSamples(file_root=os.path.join(chains_dir, 'Planck18TTTEEE'), no_cache=True, settings=getdist_settings)
# the DES Y1 3x2 chain:
chain_2 = getdist.mcsamples.loadMCSamples(file_root=os.path.join(chains_dir, 'DES'), no_cache=True, settings=getdist_settings)
# the joint chain:
chain_12 = getdist.mcsamples.loadMCSamples(file_root=os.path.join(chains_dir, 'Planck18TTTEEE_DES'), no_cache=True, settings=getdist_settings)

# let's add omegab as a derived parameter:
for _ch in [chain_1, chain_2, chain_12]:
    _p = _ch.getParams()
    _h = _p.H0 / 100.
    _ch.addDerived(_p.omegabh2 / _h**2, name='omegab', label='\\Omega_b')
    _ch.updateBaseStatistics()

# we define the parameters of the problem:
param_names = ['H0', 'omegam', 'sigma8', 'ns', 'omegab']

# and then do a sanity check plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, chain_2, chain_12], params=param_names, filled=True)
In [24]:
# we then train the flows on the base parameters that we want to combine (note that for this exercise we should include all shared parameters):
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 3,
          'epochs': 400,
        }

# actual flow training:
flow_1 = sp.average_flow_from_chain(chain_1, param_names=param_names, **kwargs)
flow_2 = sp.average_flow_from_chain(chain_2, param_names=param_names, **kwargs)
flow_12 = sp.average_flow_from_chain(chain_12, param_names=param_names, **kwargs)

# plot to make sure training went well:
flow_1.training_plot()
flow_2.training_plot()
flow_12.training_plot()
Training flow 0
0epoch [00:00, ?epoch/s]
Training flow 1
0epoch [00:00, ?epoch/s]
Training flow 2
0epoch [00:00, ?epoch/s]
Training flow 0
0epoch [00:00, ?epoch/s]
Training flow 1
0epoch [00:00, ?epoch/s]
Training flow 2
0epoch [00:00, ?epoch/s]
Training flow 0
0epoch [00:00, ?epoch/s]
Training flow 1
0epoch [00:00, ?epoch/s]
Training flow 2
0epoch [00:00, ?epoch/s]
[figures: training summary plots for flow_1, flow_2 and flow_12]
In [25]:
# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([chain_1, flow_1.MCSamples(20000, settings=getdist_settings),
                 chain_2, flow_2.MCSamples(20000, settings=getdist_settings),
                 chain_12, flow_12.MCSamples(20000, settings=getdist_settings),
                 ], 
                params=param_names,
                filled=False)
# we log scale the y axis for the logP plot so that we can appreciate the accuracy of the flow on the tails:
for i in range(len(param_names)):
    _ax = g.subplots[i, i]
    _ax.set_yscale('log')
    _ax.set_ylim([1.e-5, 1.0])
    _ax.set_ylabel('$\\log P$')
    _ax.tick_params(axis='y', which='both', labelright='on')
    _ax.yaxis.set_label_position("right")    
[figure: triangle plot comparing each chain with its trained flow, with log-scaled 1D panels]
In [26]:
# now we can define the joint posterior:
def joint_log_posterior(H0, omegam, sigma8, ns, omegab):
    params = [H0, omegam, sigma8, ns, omegab]
    return [flow_1.log_probability(flow_1.cast([params])).numpy()[0] + flow_2.log_probability(flow_2.cast([params])).numpy()[0]]

# and sample it:
from cobaya.run import run
from getdist.mcsamples import MCSamplesFromCobaya

parameters = {}
for key in param_names:
    parameters[key] = {"prior": {"min": 1.01*max(flow_1.parameter_ranges[key][0], flow_2.parameter_ranges[key][0]),
                                 "max": 0.99*min(flow_1.parameter_ranges[key][1], flow_2.parameter_ranges[key][1])},
                       "latex": flow_1.param_labels[flow_1.param_names.index(key)]}
info = {
    "likelihood": {"joint_log_posterior": joint_log_posterior},
    "params": parameters,
    }
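A caveat on the prior ranges above: multiplying the bounds by 1.01 and 0.99 only shrinks the intersection interval when both bounds are positive (which holds for these parameters). A sign-agnostic alternative is to shrink the interval by a fraction of its width; the function and values below are illustrative, not part of tensiometer:

```python
import numpy as np

def shrink_interval(lo, hi, frac=0.01):
    # shrink [lo, hi] toward its center by `frac` of its width;
    # unlike scaling the endpoints, this works for bounds of any sign
    width = hi - lo
    return lo + frac * width, hi - frac * width

# hypothetical ranges from two flows:
range_a = (-0.5, 1.0)
range_b = (-0.2, 0.8)
lo = max(range_a[0], range_b[0])
hi = min(range_a[1], range_b[1])
lo2, hi2 = shrink_interval(lo, hi)  # ≈ (-0.19, 0.79)
```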
In [27]:
# MCMC sample:

# we need a reasonably good initial proposal and starting point; we get them from one of the flows:
flow_1_samples = flow_1.sample(10000)
flow_1_logPs = flow_1.log_probability(flow_1_samples).numpy()
flow_1_maxP_sample = flow_1_samples[np.argmax(flow_1_logPs)].numpy()

# we need a good starting point, otherwise this will take a long time...
for _i, _k in enumerate(parameters.keys()):
    info['params'][_k]['ref'] = flow_1_maxP_sample[_i]

info["sampler"] = {"mcmc": 
                {'covmat': np.cov(flow_1_samples.numpy().T),
                 'covmat_params': param_names,
                 'max_tries': np.inf,
                 'Rminus1_stop': 0.01,
                 'learn_proposal_Rminus1_max': 30.,
                 'learn_proposal_Rminus1_max_early': 30.,
                 'measure_speeds': False,
                 'Rminus1_single_split': 10,
                 }}
info['debug'] = 100  # hack to suppress cobaya's very verbose output
updated_info, sampler = run(info)
joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], ignore_rows=0.3, settings=getdist_settings)
In [28]:
## Nested sampling sample:
#_dim = len(flow_1.param_names)
#
#info["sampler"] = {"polychord": {'nlive': 50*_dim,
#                                 'measure_speeds': False,
#                                 'num_repeats': 2*_dim,
#                                 'nprior': 10*25*_dim,
#                                 'do_clustering': True,
#                                 'precision_criterion': 0.01,
#                                 'boost_posterior': 10, 
#                                 'feedback': 0,
#                                 },
#                    }
#info['debug'] = 100  # note this is an insane hack to disable very verbose output...
#updated_info, sampler = run(info)
#joint_chain = MCSamplesFromCobaya(updated_info, sampler.products()["sample"], settings=getdist_settings)
In [29]:
joint_chain.name_tag = 'Flow Joint'
chain_12.name_tag = 'Real Joint (Planck + DES)'

# sanity check triangle plot:
g = plots.get_subplot_plotter()
g.triangle_plot([joint_chain, chain_12], 
                params=param_names,
                filled=False)
[figure: triangle plot comparing the flow-based joint with the real joint chain]

As we can see this works fairly well, given that the two experiments are in some tension and their posteriors do not overlap significantly.

Make sure you check the consistency of the experiments before combining them, so that the joint posterior samples a well-trained region of both flows.

You can check the example notebook in this documentation for how to compute tensions between two experiments.

Real application: accurate likelihood values¶

For some applications we need to push the local accuracy of the flow model. In this case we need to provide exact probability values (up to a normalization constant) for the training set.

These are then used to build a second term of the loss function that rewards local accuracy of the probability values: the estimated evidence error. By default the code adaptively mixes the two loss terms to find an optimal solution.

As a downside, we can only train a flow that preserves all the parameters of the distribution, i.e., we cannot train on a marginalized posterior (as we did in the previous examples).

For more details see Raveri, Doux and Pandey (2024), arXiv:2409.09101.
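One common way to mix two loss terms adaptively is a SoftAdapt-style scheme: weight each term by a softmax of its recent rate of change, so the term that is improving least receives the largest weight. A minimal sketch of the idea (illustrative; the function name and details are ours, not the tensiometer implementation):

```python
import numpy as np

def softadapt_weights(loss_histories, beta=0.1):
    # rate of change of each loss over the last two epochs:
    rates = np.array([h[-1] - h[-2] for h in loss_histories])
    # stabilized softmax: the loss that is improving least gets the largest weight
    z = beta * (rates - rates.max())
    w = np.exp(z)
    return w / w.sum()

# two loss terms: the first is still improving fast,
# the second has almost plateaued
w = softadapt_weights([[1.0, 0.8], [0.5, 0.49]])
# the plateaued loss receives the larger weight: w[1] > w[0]
```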

In [30]:
ev, eer = flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.11619382351636887 +- 0.6003727316856384

We can see that the value is close to what it should be (zero, since the original distribution is normalized), but the estimated error is still fairly high.
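To see where such an estimate can come from: for chain samples distributed as $P/Z$, with unnormalized log-posterior values $\log p$ and a normalized flow density $q$, we have $\mathbb{E}_{P/Z}[q/p] = 1/Z$. A minimal numpy sketch of this reciprocal importance-sampling estimator (illustrative; not necessarily the exact tensiometer implementation):

```python
import numpy as np

def log_evidence_estimate(logp_unnorm, logq_flow):
    # log Z = -logmeanexp(log q - log p), computed stably
    delta = logq_flow - logp_unnorm
    m = delta.max()
    return -(m + np.log(np.mean(np.exp(delta - m))))

# check on a 1D Gaussian, where the true log Z = 0.5*log(2*pi):
rng = np.random.default_rng(0)
x = rng.standard_normal(5000)
logp_true = -0.5 * x**2                              # unnormalized log-posterior
logq_flow = -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)  # a "perfect flow"
est = log_evidence_estimate(logp_true, logq_flow)
# est recovers 0.5*log(2*pi) exactly here, since q/p is constant
```

The variance of this estimator is driven by the samples where the flow and the true posterior disagree most, which is why tail outliers dominate the quoted error.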

Since we have (normalized) log P values we can check the local reliability of the normalizing flow:

In [31]:
validation_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = flow.log_probability(flow.cast(chain.samples[flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[figure: flow vs. true log-probability residuals, scatter and histogram]

We can clearly see that the local accuracy of the flow in full dimension is not high: errors grow rapidly as we move toward the tails. The scatter in this plot sets the estimated error on the evidence, which is rather large and dominated by the outliers in the tails.

Averaging flows usually improves the situation, in particular on the validation sample.
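An average flow combines its member flows as a (possibly weighted) mixture, so its log-density can be evaluated with a logsumexp over the members' log-densities. A minimal numpy sketch of this idea (illustrative; not the tensiometer API):

```python
import numpy as np

def mixture_log_prob(member_log_probs, log_weights=None):
    # member_log_probs: shape (K, N), K flows evaluated at N points
    lp = np.asarray(member_log_probs)
    K = lp.shape[0]
    if log_weights is None:
        log_weights = np.full(K, -np.log(K))  # equal weights
    z = lp + log_weights[:, None]
    # stabilized logsumexp over the K members:
    m = z.max(axis=0)
    return m + np.log(np.exp(z - m).sum(axis=0))

# equal-weight mixture of two flows evaluated at two points:
logps = np.array([[-1.0, -2.0],
                  [-1.0, -4.0]])
mix = mixture_log_prob(logps)
```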

In [32]:
ev, eer = average_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.0537954606115818 +- 0.4508938491344452
In [33]:
validation_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow.log_probability(average_flow.cast(chain.samples[average_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[figure: flow vs. true log-probability residuals, scatter and histogram]

This looks significantly better, and in fact the error on the evidence estimate is lower...

If we want to do better we need to train with the evidence-error loss, as discussed in the reference paper for this example notebook.

In [34]:
kwargs = {
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
          'pop_size': 1,
          'num_flows': 1,
          'epochs': 400,
          'loss_mode': 'softadapt',
        }

average_flow_2 = sp.average_flow_from_chain(chain,  # parameter difference chain
                                            **kwargs)
Training flow 0
0epoch [00:00, ?epoch/s]
In [35]:
average_flow_2.training_plot()
[figure: training summary plot for the evidence-error-trained flow]

As we can see, the training plots are substantially more complicated, since we are now monitoring several additional quantities.

In [36]:
ev, eer = average_flow_2.evidence()
print(f'log(Z) = {ev} +- {eer}')
log(Z) = 0.11429519206285477 +- 0.42347487807273865
In [37]:
validation_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[average_flow_2.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = average_flow_2.log_probability(average_flow_2.cast(chain.samples[average_flow_2.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[average_flow_2.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
[figure: flow vs. true log-probability residuals, scatter and histogram]

As we can see, this achieves performance very close to that of averaged flows. Combining the two strategies achieves the best performance (but is slower to train).

Advanced Topic: Spline Flows¶

When more flexibility in the normalizing flow model is needed, we provide an implementation of neural spline flows, as discussed in Durkan et al. (2019), arXiv:1906.04032.
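The key ingredient is a monotonic spline bijector whose forward map, inverse and log-Jacobian determinant are all cheap and exact; in neural spline flows, the knot positions and slopes are predicted by a neural network. A toy piecewise-linear version illustrates the mechanics (the actual paper uses monotonic rational-quadratic splines; everything below is illustrative):

```python
import numpy as np

def linear_spline_bijector(x, bin_edges, bin_heights):
    # toy monotonic piecewise-linear spline on [0, 1]:
    # cumulative bin heights define the output knots; positive
    # heights guarantee monotonicity (hence invertibility)
    y_knots = np.concatenate([[0.0], np.cumsum(bin_heights)])
    y_knots /= y_knots[-1]  # normalize so the map is [0, 1] -> [0, 1]
    # locate the bin of each input point:
    idx = np.clip(np.searchsorted(bin_edges, x, side='right') - 1,
                  0, len(bin_heights) - 1)
    slope = (y_knots[idx + 1] - y_knots[idx]) / (bin_edges[idx + 1] - bin_edges[idx])
    y = y_knots[idx] + slope * (x - bin_edges[idx])
    log_det = np.log(slope)  # exact log|det J|, piecewise constant
    return y, log_det

edges = np.array([0.0, 0.5, 1.0])
heights = np.array([1.0, 3.0])  # relative mass per bin
y, log_det = linear_spline_bijector(np.array([0.25, 0.75]), edges, heights)
# y ≈ [0.125, 0.625]; the steeper second bin gives log_det = log([0.5, 1.5])
```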

In [38]:
kwargs = {
          # flow settings:
          'pop_size': 1,
          'num_flows': 1,
          'epochs': 400,
          'transformation_type': 'spline',
          'autoregressive_type': 'masked',
          # feedback flags:
          'feedback': 1,
          'verbose': -1,
          'plot_every': 1000,
        }

spline_flow = sp.flow_from_chain(chain,  # parameter difference chain
                                 **kwargs)
* Initializing samples
    - time taken: 0.0009 seconds
* Initializing fixed bijector
    - time taken: 0.1424 seconds
* Initializing trainable bijector
WARNING: range_max should be larger than the maximum range of the data and is beeing adjusted.
    range_max: 5.0
    max range: 9.511670112609863
    new range_max: 10.51167
    - time taken: 0.9609 seconds
* Initializing training dataset
    - time taken: 1.0433 seconds
* Initializing transformed distribution
    - time taken: 0.0073 seconds
* Initializing loss function
    - time taken: 0.0000 seconds
* Initializing training model
    - Compiling model
    - time taken: 0.0294 seconds
    - time taken: 7.0743 seconds
* Training
    - Compiling model
    - time taken: 0.0283 seconds
0epoch [00:00, ?epoch/s]
0batch [00:00, ?batch/s]
In [39]:
# we can plot training summaries to make sure training went smoothly:
spline_flow.training_plot()
[figure: training summary plot for the spline flow]
In [40]:
# we can triangle plot the flow to see how well it has learned the target distribution:
g = plots.get_subplot_plotter()
g.triangle_plot([chain, 
                 spline_flow.MCSamples(20000)  # this flow method returns a MCSamples object
                 ], 
                params=spline_flow.param_names,
                filled=True)
[figure: triangle plot of the chain and the spline flow samples]
In [41]:
samples = spline_flow.MCSamples(20000)
logP = spline_flow.log_probability(spline_flow.cast(samples.samples)).numpy()
samples.addDerived(logP, name='logP', label='\\log P')
samples.updateBaseStatistics();

# now let's plot everything:
g = plots.get_subplot_plotter()
g.triangle_plot([samples, chain], 
                plot_3d_with_param='logP',
                filled=False)
[figure: triangle plot colored by logP]
In [42]:
validation_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.test_idx, :])).numpy()/np.log(10.)
validation_samples_log10_P = -chain.loglikes[spline_flow.test_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist
training_flow_log10_P = spline_flow.log_probability(spline_flow.cast(chain.samples[spline_flow.training_idx, :])).numpy()/np.log(10.)
training_samples_log10_P = -chain.loglikes[spline_flow.training_idx]/np.log(10.)  # notice the minus sign due to the definition of logP in getdist

ev, eer = spline_flow.evidence()
print(f'log(Z) = {ev} +- {eer}')

# do the plot:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.scatter(training_samples_log10_P - np.amax(training_samples_log10_P), training_flow_log10_P - training_samples_log10_P, s=0.1, label='training')
ax1.scatter(validation_samples_log10_P - np.amax(validation_samples_log10_P), validation_flow_log10_P - validation_samples_log10_P, s=0.5, label='validation')
ax1.legend()
ax1.axhline(0, color='k', linestyle='--')
ax1.set_xlabel('$\log_{10}(P_{\\rm true}/P_{\\rm max})$')
ax1.set_ylabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax1.set_ylim(-1.0, 1.0)

ax2.hist(training_flow_log10_P - training_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='training')
ax2.hist(validation_flow_log10_P - validation_samples_log10_P, bins=50, range=[-1., 1.], density=True, alpha=0.5, label='validation')
ax2.legend()
ax2.axvline(0, color='k', linestyle='--')
ax2.set_xlabel('$\log_{10}(P_{\\rm flow}) - \log_{10}(P_{\\rm true})$')
ax2.set_xlim([-1.0, 1.0])

plt.tight_layout()
plt.show()
log(Z) = 0.04767173156142235 +- 0.6037229299545288
[figure: flow vs. true log-probability residuals, scatter and histogram]

We can check what happens across the bijector layers:

In [43]:
from tensiometer.synthetic_probability import flow_utilities as flow_utils

training_samples_spaces, validation_samples_spaces = \
    flow_utils.get_samples_bijectors(spline_flow, 
                                     feedback=True)
    
for i, _s in enumerate(training_samples_spaces):
    print('*  ', _s.name_tag)
    g = plots.get_subplot_plotter()
    g.triangle_plot([
                    training_samples_spaces[i],
                    validation_samples_spaces[i]], 
                    filled=True,
                    )
    plt.show()
0 - bijector name:  permute
1 - bijector name:  spline_flow
2 - bijector name:  permute
3 - bijector name:  spline_flow
4 - bijector name:  permute
5 - bijector name:  spline_flow
6 - bijector name:  permute
7 - bijector name:  spline_flow
8 - bijector name:  permute
9 - bijector name:  spline_flow
10 - bijector name:  permute
11 - bijector name:  spline_flow
12 - bijector name:  permute
13 - bijector name:  spline_flow
14 - bijector name:  permute
15 - bijector name:  spline_flow
*   original_space
*   training_space
*   0_after_permute
*   1_after_spline_flow
*   2_after_permute
*   3_after_spline_flow
*   4_after_permute
*   5_after_spline_flow
*   6_after_permute
*   7_after_spline_flow
*   8_after_permute
*   9_after_spline_flow
*   10_after_permute
*   11_after_spline_flow
*   12_after_permute
*   13_after_spline_flow
*   14_after_permute
*   15_after_spline_flow
[figures: one triangle plot of training vs. validation samples for each of the spaces listed above]